Skip to content

Conversation

@topperc
Copy link
Collaborator

@topperc topperc commented Apr 26, 2025

Create an extra overload of getStore that can handle of the 3 types of stores. This is similar to how getLoad/getExtLoad/getIndexLoad is structure.

…e, and getIndexedStore.^

Create an extra overload of getStore than can handle of the 3
types of stores. This is similar to how getLoad/getExtLoad/getIndexLoad
is structure.
@llvmbot llvmbot added the llvm:SelectionDAG SelectionDAGISel as well label Apr 26, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 26, 2025

@llvm/pr-subscribers-llvm-selectiondag

Author: Craig Topper (topperc)

Changes

Create an extra overload of getStore that can handle of the 3 types of stores. This is similar to how getLoad/getExtLoad/getIndexLoad is structure.


Full diff: https://github.com/llvm/llvm-project/pull/137435.diff

2 Files Affected:

  • (modified) llvm/include/llvm/CodeGen/SelectionDAG.h (+3)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp (+38-71)
diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
index c183149b0863a..ba11ddbb5b731 100644
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -1466,6 +1466,9 @@ class SelectionDAG {
                         SDValue Ptr, EVT SVT, MachineMemOperand *MMO);
   SDValue getIndexedStore(SDValue OrigStore, const SDLoc &dl, SDValue Base,
                           SDValue Offset, ISD::MemIndexedMode AM);
+  SDValue getStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr,
+                   SDValue Offset, EVT SVT, MachineMemOperand *MMO,
+                   ISD::MemIndexedMode AM, bool IsTruncating = false);
 
   SDValue getLoadVP(ISD::MemIndexedMode AM, ISD::LoadExtType ExtType, EVT VT,
                     const SDLoc &dl, SDValue Chain, SDValue Ptr, SDValue Offset,
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 9da2ba04f77cb..8cc434a49afa2 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -9439,17 +9439,42 @@ SDValue SelectionDAG::getStore(SDValue Chain, const SDLoc &dl, SDValue Val,
 
 SDValue SelectionDAG::getStore(SDValue Chain, const SDLoc &dl, SDValue Val,
                                SDValue Ptr, MachineMemOperand *MMO) {
-  assert(Chain.getValueType() == MVT::Other &&
-        "Invalid chain type");
-  EVT VT = Val.getValueType();
-  SDVTList VTs = getVTList(MVT::Other);
   SDValue Undef = getUNDEF(Ptr.getValueType());
-  SDValue Ops[] = { Chain, Val, Ptr, Undef };
+  return getStore(Chain, dl, Val, Ptr, Undef, Val.getValueType(), MMO,
+                  ISD::UNINDEXED);
+}
+
+SDValue SelectionDAG::getStore(SDValue Chain, const SDLoc &dl, SDValue Val,
+                               SDValue Ptr, SDValue Offset, EVT SVT,
+                               MachineMemOperand *MMO, ISD::MemIndexedMode AM,
+                               bool IsTruncating) {
+  assert(Chain.getValueType() == MVT::Other && "Invalid chain type");
+  EVT VT = Val.getValueType();
+  if (VT == SVT) {
+    IsTruncating = false;
+  } else if (!IsTruncating) {
+    assert(VT == SVT && "No-truncating store from different memory type!");
+  } else {
+    assert(SVT.getScalarType().bitsLT(VT.getScalarType()) &&
+           "Should only be a truncating store, not extending!");
+    assert(VT.isInteger() == SVT.isInteger() && "Can't do FP-INT conversion!");
+    assert(VT.isVector() == SVT.isVector() &&
+           "Cannot use trunc store to convert to or from a vector!");
+    assert((!VT.isVector() ||
+            VT.getVectorElementCount() == SVT.getVectorElementCount()) &&
+           "Cannot use trunc store to change the number of vector elements!");
+  }
+
+  bool Indexed = AM != ISD::UNINDEXED;
+  assert((Indexed || Offset.isUndef()) && "Unindexed store with an offset!");
+  SDVTList VTs = Indexed ? getVTList(Ptr.getValueType(), MVT::Other)
+                         : getVTList(MVT::Other);
+  SDValue Ops[] = {Chain, Val, Ptr, Offset};
   FoldingSetNodeID ID;
   AddNodeIDNode(ID, ISD::STORE, VTs, Ops);
-  ID.AddInteger(VT.getRawBits());
+  ID.AddInteger(SVT.getRawBits());
   ID.AddInteger(getSyntheticNodeSubclassData<StoreSDNode>(
-      dl.getIROrder(), VTs, ISD::UNINDEXED, false, VT, MMO));
+      dl.getIROrder(), VTs, AM, IsTruncating, SVT, MMO));
   ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
   ID.AddInteger(MMO->getFlags());
   void *IP = nullptr;
@@ -9457,8 +9482,8 @@ SDValue SelectionDAG::getStore(SDValue Chain, const SDLoc &dl, SDValue Val,
     cast<StoreSDNode>(E)->refineAlignment(MMO);
     return SDValue(E, 0);
   }
-  auto *N = newSDNode<StoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
-                                   ISD::UNINDEXED, false, VT, MMO);
+  auto *N = newSDNode<StoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs, AM,
+                                   IsTruncating, SVT, MMO);
   createOperands(N, Ops);
 
   CSEMap.InsertNode(N, IP);
@@ -9491,47 +9516,8 @@ SDValue SelectionDAG::getTruncStore(SDValue Chain, const SDLoc &dl, SDValue Val,
 SDValue SelectionDAG::getTruncStore(SDValue Chain, const SDLoc &dl, SDValue Val,
                                     SDValue Ptr, EVT SVT,
                                     MachineMemOperand *MMO) {
-  EVT VT = Val.getValueType();
-
-  assert(Chain.getValueType() == MVT::Other &&
-        "Invalid chain type");
-  if (VT == SVT)
-    return getStore(Chain, dl, Val, Ptr, MMO);
-
-  assert(SVT.getScalarType().bitsLT(VT.getScalarType()) &&
-         "Should only be a truncating store, not extending!");
-  assert(VT.isInteger() == SVT.isInteger() &&
-         "Can't do FP-INT conversion!");
-  assert(VT.isVector() == SVT.isVector() &&
-         "Cannot use trunc store to convert to or from a vector!");
-  assert((!VT.isVector() ||
-          VT.getVectorElementCount() == SVT.getVectorElementCount()) &&
-         "Cannot use trunc store to change the number of vector elements!");
-
-  SDVTList VTs = getVTList(MVT::Other);
   SDValue Undef = getUNDEF(Ptr.getValueType());
-  SDValue Ops[] = { Chain, Val, Ptr, Undef };
-  FoldingSetNodeID ID;
-  AddNodeIDNode(ID, ISD::STORE, VTs, Ops);
-  ID.AddInteger(SVT.getRawBits());
-  ID.AddInteger(getSyntheticNodeSubclassData<StoreSDNode>(
-      dl.getIROrder(), VTs, ISD::UNINDEXED, true, SVT, MMO));
-  ID.AddInteger(MMO->getPointerInfo().getAddrSpace());
-  ID.AddInteger(MMO->getFlags());
-  void *IP = nullptr;
-  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) {
-    cast<StoreSDNode>(E)->refineAlignment(MMO);
-    return SDValue(E, 0);
-  }
-  auto *N = newSDNode<StoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs,
-                                   ISD::UNINDEXED, true, SVT, MMO);
-  createOperands(N, Ops);
-
-  CSEMap.InsertNode(N, IP);
-  InsertNode(N);
-  SDValue V(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getStore(Chain, dl, Val, Ptr, Undef, SVT, MMO, ISD::UNINDEXED, true);
 }
 
 SDValue SelectionDAG::getIndexedStore(SDValue OrigStore, const SDLoc &dl,
@@ -9539,28 +9525,9 @@ SDValue SelectionDAG::getIndexedStore(SDValue OrigStore, const SDLoc &dl,
                                       ISD::MemIndexedMode AM) {
   StoreSDNode *ST = cast<StoreSDNode>(OrigStore);
   assert(ST->getOffset().isUndef() && "Store is already a indexed store!");
-  SDVTList VTs = getVTList(Base.getValueType(), MVT::Other);
-  SDValue Ops[] = { ST->getChain(), ST->getValue(), Base, Offset };
-  FoldingSetNodeID ID;
-  AddNodeIDNode(ID, ISD::STORE, VTs, Ops);
-  ID.AddInteger(ST->getMemoryVT().getRawBits());
-  ID.AddInteger(ST->getRawSubclassData());
-  ID.AddInteger(ST->getPointerInfo().getAddrSpace());
-  ID.AddInteger(ST->getMemOperand()->getFlags());
-  void *IP = nullptr;
-  if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP))
-    return SDValue(E, 0);
-
-  auto *N = newSDNode<StoreSDNode>(dl.getIROrder(), dl.getDebugLoc(), VTs, AM,
-                                   ST->isTruncatingStore(), ST->getMemoryVT(),
-                                   ST->getMemOperand());
-  createOperands(N, Ops);
-
-  CSEMap.InsertNode(N, IP);
-  InsertNode(N);
-  SDValue V(N, 0);
-  NewSDValueDbgMsg(V, "Creating new node: ", this);
-  return V;
+  return getStore(ST->getChain(), dl, ST->getValue(), Base, Offset,
+                  ST->getMemoryVT(), ST->getMemOperand(), AM,
+                  ST->isTruncatingStore());
 }
 
 SDValue SelectionDAG::getLoadVP(

Copy link
Contributor

@s-barannikov s-barannikov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@topperc
Copy link
Collaborator Author

topperc commented Apr 26, 2025

I was really hoping this would allow us to cleanup this code and remove the setTruncatingStore method from StoreSDNode.

DAG.UpdateNodeOperands(N, N->getOperand(0), Const64, N->getOperand(2),
, but it causes a test to fail.

It looks like the UpdateNodeOperands doesn't cause DAGCombine to revisit the node. Using getStore does revisit the node. Then visitSTORE does a DemandedBits modification to a constant and creates a worse constant to materialize.

@s-barannikov
Copy link
Contributor

s-barannikov commented Apr 26, 2025

Mutating a node isn't nice. This appears to work:

      auto *ST = cast<StoreSDNode>(N);
      SDValue NewST = DAG.getStore(ST->getChain(), dl, Const64,
                                   ST->getBasePtr(), ST->getOffset(), MemVT,
                                   ST->getMemOperand(), ST->getAddressingMode(),
                                   /*IsTruncating=*/true);
      return ST->isUnindexed()
                 ? DCI.CombineTo(N, NewST, /*AddTo=*/false)
                 : DCI.CombineTo(N, NewST, NewST.getValue(1), /*AddTo=*/false);

@topperc
Copy link
Collaborator Author

topperc commented Apr 28, 2025

Mutating a node isn't nice. This appears to work:

      auto *ST = cast<StoreSDNode>(N);
      SDValue NewST = DAG.getStore(ST->getChain(), dl, Const64,
                                   ST->getBasePtr(), ST->getOffset(), MemVT,
                                   ST->getMemOperand(), ST->getAddressingMode(),
                                   /*IsTruncating=*/true);
      return ST->isUnindexed()
                 ? DCI.CombineTo(N, NewST, /*AddTo=*/false)
                 : DCI.CombineTo(N, NewST, NewST.getValue(1), /*AddTo=*/false);

Thanks I'll try that.

@topperc topperc merged commit e17f07c into llvm:main Apr 28, 2025
13 checks passed
@topperc topperc deleted the pr/getstore branch April 28, 2025 05:32
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
…e, and getIndexedStore. (llvm#137435)

Create an extra overload of getStore that can handle of the 3 types of
stores. This is similar to how getLoad/getExtLoad/getIndexLoad is
structure.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

llvm:SelectionDAG SelectionDAGISel as well

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants